from sqlalchemy import create_engine
from sqlalchemy.ext.declarative import declarative_base
from sqlalchemy.orm import sessionmaker, scoped_session

from python_quant.config.sys_config import SQLAlchemyConfig

def _as_bool(value):
    """Interpret a configuration value as a boolean flag.

    ``bool("False")`` is ``True`` because every non-empty string is truthy,
    so textual settings are compared against common "off" spellings instead.
    Non-string values keep ordinary truthiness.
    """
    if isinstance(value, str):
        return value.strip().lower() not in ("", "0", "false", "no", "off")
    return bool(value)


# Create the database engine.
engine = create_engine(
    SQLAlchemyConfig.SQLALCHEMY_DATABASE_URI,  # SQLAlchemy database connection URL
    # Whether to log the executed SQL statements; normally only used for debugging.
    echo=_as_bool(SQLAlchemyConfig.SQLALCHEMY_ECHO),
    pool_size=int(SQLAlchemyConfig.SQLALCHEMY_POOL_SIZE),  # connection pool size
    # NOTE(review): the setting is named "POOL_MAX_SIZE", but ``max_overflow`` is the
    # number of connections allowed *beyond* ``pool_size`` -- confirm the intended value.
    max_overflow=int(SQLAlchemyConfig.SQLALCHEMY_POOL_MAX_SIZE),
    # Age in seconds after which a pooled connection is recycled.
    pool_recycle=int(SQLAlchemyConfig.SQLALCHEMY_POOL_RECYCLE),
)
# Create the session factory, wrapped in a scoped registry.
Session = scoped_session(sessionmaker(bind=engine, autocommit=False, autoflush=False))
# Create the declarative base class for models.
# NOTE(review): newer SQLAlchemy releases expose ``declarative_base`` from
# ``sqlalchemy.orm``; confirm the import location against the pinned version.
Base_Decla = declarative_base()


# Create the tables automatically.
def create_all_tables():
    """Create all tables defined in ``Base_Decla.metadata`` on ``engine``."""
    metadata = Base_Decla.metadata
    metadata.create_all(engine)


def get_session():
    """Return the session handed out by the module-level ``Session`` registry.

    This function does not commit or close the session; that is left to the
    caller.
    """
    current = Session()
    return current


# Add a record.
def add_record(obj):
    """Add ``obj`` to a session and commit it.

    Args:
        obj: Mapped model instance to persist.

    Raises:
        Exception: Whatever ``add``/``commit`` raised, re-raised after the
            transaction has been rolled back.
    """
    session = Session()
    try:
        session.add(obj)
        session.commit()
    except Exception:
        # Roll back so the scoped session is not left in a failed transaction.
        session.rollback()
        raise
    finally:
        # Always release the session's resources, even when the commit fails.
        session.close()


# Delete a record.
def delete_record(obj):
    """Mark ``obj`` for deletion in a session and commit.

    Args:
        obj: Mapped model instance to delete.

    Raises:
        Exception: Whatever ``delete``/``commit`` raised, re-raised after the
            transaction has been rolled back.
    """
    session = Session()
    try:
        session.delete(obj)
        session.commit()
    except Exception:
        # Roll back so the scoped session is not left in a failed transaction.
        session.rollback()
        raise
    finally:
        # Always release the session's resources, even when the commit fails.
        session.close()


# Update a record.
def update_record(obj):
    """Merge the state of ``obj`` into a session and commit.

    Args:
        obj: Mapped model instance carrying the values to store.

    Raises:
        Exception: Whatever ``merge``/``commit`` raised, re-raised after the
            transaction has been rolled back.
    """
    session = Session()
    try:
        session.merge(obj)
        session.commit()
    except Exception:
        # Roll back so the scoped session is not left in a failed transaction.
        session.rollback()
        raise
    finally:
        # Always release the session's resources, even when the commit fails.
        session.close()


# Query records.
def query_record(cls, **kwargs):
    """Return every ``cls`` row whose attributes equal the given values.

    Args:
        cls: Mapped model class to query.
        **kwargs: Attribute name mapped to the value it must equal. One
            equality filter is applied per keyword; with no keywords the
            query is unfiltered.

    Returns:
        The list produced by ``query.all()``. The session is closed before
        returning.

    Raises:
        AttributeError: If a keyword is not an attribute of ``cls``.
    """
    session = Session()
    try:
        query = session.query(cls)
        for key, value in kwargs.items():
            query = query.filter(getattr(cls, key) == value)
        return query.all()
    finally:
        # Close the session even when building or running the query fails.
        session.close()
